# https://leetcode.com/problems/minimize-malware-spread-ii/

import collections

# union find, TC:O(N^2*alpha(N)), SC:O(N)
def minMalwareSpread(graph: list[list[int]], initial: list[int]) -> int:
    """Return the infected node whose complete removal minimizes the spread.

    Removing a node deletes it together with all of its edges. Clean nodes
    (those not in ``initial``) are grouped into connected components with
    union-find. Removing infected node ``i`` saves a clean component only if
    ``i`` is the sole infected node adjacent to it, so ``i``'s score is the
    total size of such components. The highest score wins; ties go to the
    smallest node index.

    Args:
        graph: n x n adjacency matrix; a non-zero graph[i][j] marks an edge.
        initial: non-empty list of initially infected node indices
            (``initial[0]`` is read unconditionally).

    Returns:
        The index of the node to remove.
    """
    n = len(graph)
    parents = list(range(n))
    size = [1] * n
    initial_set = set(initial)

    def find(x):
        # Recursive find with path compression.
        if x != parents[x]:
            parents[x] = find(parents[x])
        return parents[x]

    def union(x, y):
        # Attach y's root under x's root.
        px, py = find(x), find(y)
        # Already in the same component: merging again would add size[px] to
        # itself and double-count the component (e.g. the third edge of a
        # triangle among clean nodes).
        if px == py:
            return
        parents[py] = px
        size[px] += size[py]

    # Union clean nodes only; infected nodes are handled below.
    # TC:O(N^2*alpha(N))
    for i in range(n):
        if i in initial_set:
            continue
        for j in range(i + 1, n):
            if graph[i][j] == 0 or j in initial_set:
                continue
            union(i, j)

    # component root -> number of infected nodes adjacent to that component
    connections = collections.Counter()
    # infected node -> set of clean component roots it touches
    i_to_comp = {}
    for i in initial:
        conn_parents = {
            find(j) for j in range(n) if j not in initial_set and graph[i][j] != 0
        }
        i_to_comp[i] = conn_parents
        for c in conn_parents:  # each such component can be infected by i
            connections[c] += 1

    res = initial[0]
    score = 0
    for i in initial:
        # Only components reachable from i alone are saved by removing i.
        temp_score = sum(size[find(c)] for c in i_to_comp[i] if connections[c] == 1)
        if temp_score > score or (temp_score == score and i < res):
            res, score = i, temp_score
    return res


# concise union find, TC:O(N^2*alpha(N)), SC:O(N)
def minMalwareSpread2(graph: list[list[int]], initial: list[int]) -> int:
    """Pick the infected node to delete, using a compact union-find.

    Only clean nodes (not in ``initial``) are joined into components. Each
    infected node is mapped to the set of clean-component roots it is adjacent
    to; a component counts toward a node's score only when that node is its
    single infected neighbour. The node with the largest score wins, ties
    broken by the smaller index.
    """
    total = len(graph)
    root_of = list(range(total))
    comp_size = [1] * total
    infected = set(initial)
    healthy = [v for v in range(total) if v not in infected]

    def find(v):
        # Walk up to the root, then point every node on the path at it.
        top = v
        while root_of[top] != top:
            top = root_of[top]
        while root_of[v] != top:
            nxt = root_of[v]
            root_of[v] = top
            v = nxt
        return top

    def merge(a, b):
        ra, rb = find(a), find(b)
        if ra != rb:
            root_of[rb] = ra
            comp_size[ra] += comp_size[rb]

    # Join every ordered pair of linked clean nodes. TC:O(N^2*alpha(N))
    for a in healthy:
        for b in healthy:
            if graph[a][b] != 0:
                merge(a, b)

    touched = {}  # infected node -> clean component roots it reaches
    hits = collections.Counter()  # component root -> adjacent infected count
    for node in initial:
        roots = {find(b) for b in healthy if graph[node][b] == 1}
        touched[node] = roots
        hits.update(roots)

    def saved(node):
        # Clean nodes rescued by deleting `node`: components it alone touches.
        return sum(comp_size[find(r)] for r in touched[node] if hits[r] == 1)

    return min(initial, key=lambda node: (-saved(node), node))



# Sample input: clean nodes 1, 3, 5 hang only off infected node 8, and the
# clean pair {6, 7} hangs only off infected node 4; both solvers return 8.
graph = [
    [1, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 0, 1],
    [0, 0, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1, 0, 1, 1, 1],
    [0, 0, 0, 0, 0, 1, 0, 0, 1],
    [0, 0, 0, 0, 1, 0, 1, 1, 0],
    [0, 0, 0, 0, 1, 0, 1, 1, 0],
    [0, 1, 0, 1, 1, 1, 0, 0, 1],
]
initial = [8, 4, 2, 0]

res1, res2 = minMalwareSpread(graph, initial), minMalwareSpread2(graph, initial)